{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/curie/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "#encoding:utf8\n",
    "from __future__ import print_function\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pylab as plt\n",
    "\n",
    "import keras\n",
    "from keras.datasets import cifar10\n",
    "from keras.datasets import mnist\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense,Activation,Flatten,Dropout\n",
    "from keras.layers import Conv2D,MaxPool2D\n",
    "from keras.optimizers import RMSprop,SGD\n",
    "from keras.models import load_model\n",
    "from keras.models import model_from_json\n",
    "import scipy.io as sio\n",
    "import tensorflow as tf\n",
    "\n",
    "from flask import Flask, render_template, request\n",
    "from scipy.misc import imsave, imread, imresize\n",
    "import re\n",
    "import sys\n",
    "import os\n",
    "from keras.preprocessing import image\n",
    "from keras.applications.vgg16 import VGG16\n",
    "from keras.applications.resnet50 import ResNet50\n",
    "from keras.applications import xception\n",
    "from keras.applications import inception_v3\n",
    "from keras.applications.vgg16 import preprocess_input, decode_predictions\n",
    "from keras.applications.resnet50 import preprocess_input, decode_predictions\n",
    "from getdata import load\n",
    "from keras.models import Model\n",
    "from keras.layers import Input, Dense\n",
    "from keras import backend as K\n",
    "# K.set_image_dim_ordering('th')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_img(img_path):\n",
    "    img = image.load_img(img_path,target_size=(244,244))\n",
    "    x = image.img_to_array(img)\n",
    "    x = np.expand_dims(x,axis = 0)\n",
    "    x = preprocess_input(x)\n",
    "    return x\n",
    "# def predict_top_labels(preds,top_num):\n",
    "#     print('Predicted:', decode_predictions(preds, top=top_num)[0])\n",
    "#     return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# base_model = ResNet50(weights='imagenet')\n",
    "base_model = ResNet50(weights='imagenet',include_top=False,input_shape=(244,244,3))\n",
    "# base_model = ResNet50(include_top=False, weights=None, input_shape=(224, 224, 3))\n",
    "for layer in base_model.layers:\n",
    "    layer.trainable = False\n",
    "# model = Sequential()\n",
    "# model.add(base_model)\n",
    "# model.add(Flatten())\n",
    "# model.add(Dense(256,activation='relu'))\n",
    "# model.add(Dense(input_dim=1, activation='sigmoid',\n",
    "#             bias_initializer='normal', units=5)) \n",
    "# base_model_out = base_model.output\n",
    "# my_layer = Flatten()(base_model_out)\n",
    "# my_layer = Dense(256,activation='relu')(my_layer)\n",
    "# model = Dense(input_dim=1, activation='sigmoid',\n",
    "#             bias_initializer='normal', units=5)(my_layer) \n",
    "\n",
    "# rms = RMSprop(lr=0.001, rho=0.9, epsilon=1e-08, decay=0.0)\n",
    "# # learning rate should be low, as Resnet is ok\n",
    "# # model.compile(optimizer=rms, loss='categorical_crossentropy')\n",
    "\n",
    "# # model.compile(optimizer='rmsprop',\n",
    "# #http://www.datalearner.com/blog/1051521451493989\n",
    "# #在多标签分类中，大多使用binary_crossentropy损失而不是通常在多类分类中使用\n",
    "# # 的categorical_crossentropy损失函数。这可能看起来不合理，但因为每个输出节点都是独立的\n",
    "# # ，选择二元损失，并将网络输出建模为每个标签独立的bernoulli分布。\n",
    "# model.compile(optimizer=rms ,\n",
    "#             loss='binary_crossentropy',\n",
    "#             metrics=['accuracy'])\n",
    "\n",
    "# model.summary()\n",
    "\n",
    "# from keras.utils import plot_model\n",
    "# plot_model(model, to_file='model.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            (None, 244, 244, 3)  0                                            \n",
      "__________________________________________________________________________________________________\n",
      "conv1_pad (ZeroPadding2D)       (None, 250, 250, 3)  0           input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "conv1 (Conv2D)                  (None, 122, 122, 64) 9472        conv1_pad[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "bn_conv1 (BatchNormalization)   (None, 122, 122, 64) 256         conv1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "activation_1 (Activation)       (None, 122, 122, 64) 0           bn_conv1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2D)  (None, 60, 60, 64)   0           activation_1[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "res2a_branch2a (Conv2D)         (None, 60, 60, 64)   4160        max_pooling2d_1[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "bn2a_branch2a (BatchNormalizati (None, 60, 60, 64)   256         res2a_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_2 (Activation)       (None, 60, 60, 64)   0           bn2a_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res2a_branch2b (Conv2D)         (None, 60, 60, 64)   36928       activation_2[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "bn2a_branch2b (BatchNormalizati (None, 60, 60, 64)   256         res2a_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_3 (Activation)       (None, 60, 60, 64)   0           bn2a_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res2a_branch2c (Conv2D)         (None, 60, 60, 256)  16640       activation_3[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "res2a_branch1 (Conv2D)          (None, 60, 60, 256)  16640       max_pooling2d_1[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "bn2a_branch2c (BatchNormalizati (None, 60, 60, 256)  1024        res2a_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "bn2a_branch1 (BatchNormalizatio (None, 60, 60, 256)  1024        res2a_branch1[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "add_1 (Add)                     (None, 60, 60, 256)  0           bn2a_branch2c[0][0]              \n",
      "                                                                 bn2a_branch1[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "activation_4 (Activation)       (None, 60, 60, 256)  0           add_1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "res2b_branch2a (Conv2D)         (None, 60, 60, 64)   16448       activation_4[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "bn2b_branch2a (BatchNormalizati (None, 60, 60, 64)   256         res2b_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_5 (Activation)       (None, 60, 60, 64)   0           bn2b_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res2b_branch2b (Conv2D)         (None, 60, 60, 64)   36928       activation_5[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "bn2b_branch2b (BatchNormalizati (None, 60, 60, 64)   256         res2b_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_6 (Activation)       (None, 60, 60, 64)   0           bn2b_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res2b_branch2c (Conv2D)         (None, 60, 60, 256)  16640       activation_6[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "bn2b_branch2c (BatchNormalizati (None, 60, 60, 256)  1024        res2b_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "add_2 (Add)                     (None, 60, 60, 256)  0           bn2b_branch2c[0][0]              \n",
      "                                                                 activation_4[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "activation_7 (Activation)       (None, 60, 60, 256)  0           add_2[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "res2c_branch2a (Conv2D)         (None, 60, 60, 64)   16448       activation_7[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "bn2c_branch2a (BatchNormalizati (None, 60, 60, 64)   256         res2c_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_8 (Activation)       (None, 60, 60, 64)   0           bn2c_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res2c_branch2b (Conv2D)         (None, 60, 60, 64)   36928       activation_8[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "bn2c_branch2b (BatchNormalizati (None, 60, 60, 64)   256         res2c_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_9 (Activation)       (None, 60, 60, 64)   0           bn2c_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res2c_branch2c (Conv2D)         (None, 60, 60, 256)  16640       activation_9[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "bn2c_branch2c (BatchNormalizati (None, 60, 60, 256)  1024        res2c_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "add_3 (Add)                     (None, 60, 60, 256)  0           bn2c_branch2c[0][0]              \n",
      "                                                                 activation_7[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "activation_10 (Activation)      (None, 60, 60, 256)  0           add_3[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "res3a_branch2a (Conv2D)         (None, 30, 30, 128)  32896       activation_10[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn3a_branch2a (BatchNormalizati (None, 30, 30, 128)  512         res3a_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_11 (Activation)      (None, 30, 30, 128)  0           bn3a_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res3a_branch2b (Conv2D)         (None, 30, 30, 128)  147584      activation_11[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn3a_branch2b (BatchNormalizati (None, 30, 30, 128)  512         res3a_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_12 (Activation)      (None, 30, 30, 128)  0           bn3a_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res3a_branch2c (Conv2D)         (None, 30, 30, 512)  66048       activation_12[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res3a_branch1 (Conv2D)          (None, 30, 30, 512)  131584      activation_10[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn3a_branch2c (BatchNormalizati (None, 30, 30, 512)  2048        res3a_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "bn3a_branch1 (BatchNormalizatio (None, 30, 30, 512)  2048        res3a_branch1[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "add_4 (Add)                     (None, 30, 30, 512)  0           bn3a_branch2c[0][0]              \n",
      "                                                                 bn3a_branch1[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "activation_13 (Activation)      (None, 30, 30, 512)  0           add_4[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "res3b_branch2a (Conv2D)         (None, 30, 30, 128)  65664       activation_13[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn3b_branch2a (BatchNormalizati (None, 30, 30, 128)  512         res3b_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_14 (Activation)      (None, 30, 30, 128)  0           bn3b_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res3b_branch2b (Conv2D)         (None, 30, 30, 128)  147584      activation_14[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn3b_branch2b (BatchNormalizati (None, 30, 30, 128)  512         res3b_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_15 (Activation)      (None, 30, 30, 128)  0           bn3b_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res3b_branch2c (Conv2D)         (None, 30, 30, 512)  66048       activation_15[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn3b_branch2c (BatchNormalizati (None, 30, 30, 512)  2048        res3b_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "add_5 (Add)                     (None, 30, 30, 512)  0           bn3b_branch2c[0][0]              \n",
      "                                                                 activation_13[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "activation_16 (Activation)      (None, 30, 30, 512)  0           add_5[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "res3c_branch2a (Conv2D)         (None, 30, 30, 128)  65664       activation_16[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn3c_branch2a (BatchNormalizati (None, 30, 30, 128)  512         res3c_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_17 (Activation)      (None, 30, 30, 128)  0           bn3c_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res3c_branch2b (Conv2D)         (None, 30, 30, 128)  147584      activation_17[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn3c_branch2b (BatchNormalizati (None, 30, 30, 128)  512         res3c_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_18 (Activation)      (None, 30, 30, 128)  0           bn3c_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res3c_branch2c (Conv2D)         (None, 30, 30, 512)  66048       activation_18[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn3c_branch2c (BatchNormalizati (None, 30, 30, 512)  2048        res3c_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "add_6 (Add)                     (None, 30, 30, 512)  0           bn3c_branch2c[0][0]              \n",
      "                                                                 activation_16[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "activation_19 (Activation)      (None, 30, 30, 512)  0           add_6[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "res3d_branch2a (Conv2D)         (None, 30, 30, 128)  65664       activation_19[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn3d_branch2a (BatchNormalizati (None, 30, 30, 128)  512         res3d_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_20 (Activation)      (None, 30, 30, 128)  0           bn3d_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res3d_branch2b (Conv2D)         (None, 30, 30, 128)  147584      activation_20[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn3d_branch2b (BatchNormalizati (None, 30, 30, 128)  512         res3d_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_21 (Activation)      (None, 30, 30, 128)  0           bn3d_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res3d_branch2c (Conv2D)         (None, 30, 30, 512)  66048       activation_21[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn3d_branch2c (BatchNormalizati (None, 30, 30, 512)  2048        res3d_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "add_7 (Add)                     (None, 30, 30, 512)  0           bn3d_branch2c[0][0]              \n",
      "                                                                 activation_19[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "activation_22 (Activation)      (None, 30, 30, 512)  0           add_7[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "res4a_branch2a (Conv2D)         (None, 15, 15, 256)  131328      activation_22[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4a_branch2a (BatchNormalizati (None, 15, 15, 256)  1024        res4a_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_23 (Activation)      (None, 15, 15, 256)  0           bn4a_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4a_branch2b (Conv2D)         (None, 15, 15, 256)  590080      activation_23[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4a_branch2b (BatchNormalizati (None, 15, 15, 256)  1024        res4a_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_24 (Activation)      (None, 15, 15, 256)  0           bn4a_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4a_branch2c (Conv2D)         (None, 15, 15, 1024) 263168      activation_24[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4a_branch1 (Conv2D)          (None, 15, 15, 1024) 525312      activation_22[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4a_branch2c (BatchNormalizati (None, 15, 15, 1024) 4096        res4a_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "bn4a_branch1 (BatchNormalizatio (None, 15, 15, 1024) 4096        res4a_branch1[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "add_8 (Add)                     (None, 15, 15, 1024) 0           bn4a_branch2c[0][0]              \n",
      "                                                                 bn4a_branch1[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "activation_25 (Activation)      (None, 15, 15, 1024) 0           add_8[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "res4b_branch2a (Conv2D)         (None, 15, 15, 256)  262400      activation_25[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4b_branch2a (BatchNormalizati (None, 15, 15, 256)  1024        res4b_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_26 (Activation)      (None, 15, 15, 256)  0           bn4b_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4b_branch2b (Conv2D)         (None, 15, 15, 256)  590080      activation_26[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4b_branch2b (BatchNormalizati (None, 15, 15, 256)  1024        res4b_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_27 (Activation)      (None, 15, 15, 256)  0           bn4b_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4b_branch2c (Conv2D)         (None, 15, 15, 1024) 263168      activation_27[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4b_branch2c (BatchNormalizati (None, 15, 15, 1024) 4096        res4b_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "add_9 (Add)                     (None, 15, 15, 1024) 0           bn4b_branch2c[0][0]              \n",
      "                                                                 activation_25[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "activation_28 (Activation)      (None, 15, 15, 1024) 0           add_9[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "res4c_branch2a (Conv2D)         (None, 15, 15, 256)  262400      activation_28[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4c_branch2a (BatchNormalizati (None, 15, 15, 256)  1024        res4c_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_29 (Activation)      (None, 15, 15, 256)  0           bn4c_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4c_branch2b (Conv2D)         (None, 15, 15, 256)  590080      activation_29[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4c_branch2b (BatchNormalizati (None, 15, 15, 256)  1024        res4c_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_30 (Activation)      (None, 15, 15, 256)  0           bn4c_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4c_branch2c (Conv2D)         (None, 15, 15, 1024) 263168      activation_30[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4c_branch2c (BatchNormalizati (None, 15, 15, 1024) 4096        res4c_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "add_10 (Add)                    (None, 15, 15, 1024) 0           bn4c_branch2c[0][0]              \n",
      "                                                                 activation_28[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "activation_31 (Activation)      (None, 15, 15, 1024) 0           add_10[0][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "res4d_branch2a (Conv2D)         (None, 15, 15, 256)  262400      activation_31[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4d_branch2a (BatchNormalizati (None, 15, 15, 256)  1024        res4d_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_32 (Activation)      (None, 15, 15, 256)  0           bn4d_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4d_branch2b (Conv2D)         (None, 15, 15, 256)  590080      activation_32[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4d_branch2b (BatchNormalizati (None, 15, 15, 256)  1024        res4d_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_33 (Activation)      (None, 15, 15, 256)  0           bn4d_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4d_branch2c (Conv2D)         (None, 15, 15, 1024) 263168      activation_33[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4d_branch2c (BatchNormalizati (None, 15, 15, 1024) 4096        res4d_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "add_11 (Add)                    (None, 15, 15, 1024) 0           bn4d_branch2c[0][0]              \n",
      "                                                                 activation_31[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "activation_34 (Activation)      (None, 15, 15, 1024) 0           add_11[0][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "res4e_branch2a (Conv2D)         (None, 15, 15, 256)  262400      activation_34[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4e_branch2a (BatchNormalizati (None, 15, 15, 256)  1024        res4e_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_35 (Activation)      (None, 15, 15, 256)  0           bn4e_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4e_branch2b (Conv2D)         (None, 15, 15, 256)  590080      activation_35[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4e_branch2b (BatchNormalizati (None, 15, 15, 256)  1024        res4e_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_36 (Activation)      (None, 15, 15, 256)  0           bn4e_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4e_branch2c (Conv2D)         (None, 15, 15, 1024) 263168      activation_36[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4e_branch2c (BatchNormalizati (None, 15, 15, 1024) 4096        res4e_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "add_12 (Add)                    (None, 15, 15, 1024) 0           bn4e_branch2c[0][0]              \n",
      "                                                                 activation_34[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "activation_37 (Activation)      (None, 15, 15, 1024) 0           add_12[0][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "res4f_branch2a (Conv2D)         (None, 15, 15, 256)  262400      activation_37[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4f_branch2a (BatchNormalizati (None, 15, 15, 256)  1024        res4f_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_38 (Activation)      (None, 15, 15, 256)  0           bn4f_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4f_branch2b (Conv2D)         (None, 15, 15, 256)  590080      activation_38[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4f_branch2b (BatchNormalizati (None, 15, 15, 256)  1024        res4f_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_39 (Activation)      (None, 15, 15, 256)  0           bn4f_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res4f_branch2c (Conv2D)         (None, 15, 15, 1024) 263168      activation_39[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn4f_branch2c (BatchNormalizati (None, 15, 15, 1024) 4096        res4f_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "add_13 (Add)                    (None, 15, 15, 1024) 0           bn4f_branch2c[0][0]              \n",
      "                                                                 activation_37[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "activation_40 (Activation)      (None, 15, 15, 1024) 0           add_13[0][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "res5a_branch2a (Conv2D)         (None, 8, 8, 512)    524800      activation_40[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn5a_branch2a (BatchNormalizati (None, 8, 8, 512)    2048        res5a_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_41 (Activation)      (None, 8, 8, 512)    0           bn5a_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res5a_branch2b (Conv2D)         (None, 8, 8, 512)    2359808     activation_41[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn5a_branch2b (BatchNormalizati (None, 8, 8, 512)    2048        res5a_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_42 (Activation)      (None, 8, 8, 512)    0           bn5a_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res5a_branch2c (Conv2D)         (None, 8, 8, 2048)   1050624     activation_42[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res5a_branch1 (Conv2D)          (None, 8, 8, 2048)   2099200     activation_40[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn5a_branch2c (BatchNormalizati (None, 8, 8, 2048)   8192        res5a_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "bn5a_branch1 (BatchNormalizatio (None, 8, 8, 2048)   8192        res5a_branch1[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "add_14 (Add)                    (None, 8, 8, 2048)   0           bn5a_branch2c[0][0]              \n",
      "                                                                 bn5a_branch1[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "activation_43 (Activation)      (None, 8, 8, 2048)   0           add_14[0][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "res5b_branch2a (Conv2D)         (None, 8, 8, 512)    1049088     activation_43[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn5b_branch2a (BatchNormalizati (None, 8, 8, 512)    2048        res5b_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_44 (Activation)      (None, 8, 8, 512)    0           bn5b_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res5b_branch2b (Conv2D)         (None, 8, 8, 512)    2359808     activation_44[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn5b_branch2b (BatchNormalizati (None, 8, 8, 512)    2048        res5b_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_45 (Activation)      (None, 8, 8, 512)    0           bn5b_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res5b_branch2c (Conv2D)         (None, 8, 8, 2048)   1050624     activation_45[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn5b_branch2c (BatchNormalizati (None, 8, 8, 2048)   8192        res5b_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "add_15 (Add)                    (None, 8, 8, 2048)   0           bn5b_branch2c[0][0]              \n",
      "                                                                 activation_43[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "activation_46 (Activation)      (None, 8, 8, 2048)   0           add_15[0][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "res5c_branch2a (Conv2D)         (None, 8, 8, 512)    1049088     activation_46[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn5c_branch2a (BatchNormalizati (None, 8, 8, 512)    2048        res5c_branch2a[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_47 (Activation)      (None, 8, 8, 512)    0           bn5c_branch2a[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res5c_branch2b (Conv2D)         (None, 8, 8, 512)    2359808     activation_47[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn5c_branch2b (BatchNormalizati (None, 8, 8, 512)    2048        res5c_branch2b[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "activation_48 (Activation)      (None, 8, 8, 512)    0           bn5c_branch2b[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "res5c_branch2c (Conv2D)         (None, 8, 8, 2048)   1050624     activation_48[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "bn5c_branch2c (BatchNormalizati (None, 8, 8, 2048)   8192        res5c_branch2c[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "add_16 (Add)                    (None, 8, 8, 2048)   0           bn5c_branch2c[0][0]              \n",
      "                                                                 activation_46[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "activation_49 (Activation)      (None, 8, 8, 2048)   0           add_16[0][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "avg_pool (AveragePooling2D)     (None, 1, 1, 2048)   0           activation_49[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "my_flatten (Flatten)            (None, 2048)         0           avg_pool[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 256)          524544      my_flatten[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 5)            1285        dense_1[0][0]                    \n",
      "==================================================================================================\n",
      "Total params: 24,113,541\n",
      "Trainable params: 525,829\n",
      "Non-trainable params: 23,587,712\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "base_model_out = base_model.output\n",
    "# my_layer = base_model.add(Flatten())\n",
    "base_model_out=Flatten(name='my_flatten')(base_model_out)\n",
    "my_layer = Dense(256,activation='relu')(base_model_out)\n",
    "predictions = Dense(input_dim=1, activation='sigmoid',\n",
    "            bias_initializer='normal', units=5)(my_layer) \n",
    "\n",
    "rms = RMSprop(lr=0.001, rho=0.9, epsilon=1e-08, decay=0.0)\n",
    "# learning rate should be low, as Resnet is ok\n",
    "# model.compile(optimizer=rms, loss='categorical_crossentropy')\n",
    "\n",
    "# model.compile(optimizer='rmsprop',\n",
    "#http://www.datalearner.com/blog/1051521451493989\n",
    "#在多标签分类中，大多使用binary_crossentropy损失而不是通常在多类分类中使用\n",
    "# 的categorical_crossentropy损失函数。这可能看起来不合理，但因为每个输出节点都是独立的\n",
    "# ，选择二元损失，并将网络输出建模为每个标签独立的bernoulli分布。\n",
    "model = Model(inputs=base_model.input, outputs=predictions)\n",
    "model.compile(optimizer=rms ,\n",
    "            loss='binary_crossentropy',\n",
    "            metrics=['accuracy'])\n",
    "\n",
    "model.summary()\n",
    "\n",
    "from keras.utils import plot_model\n",
    "plot_model(model, to_file='model.png')\n",
    "\n",
    "# base_model = resnet50 (weights='imagenet', include_top=False)\n",
    "\n",
    "# # add a global spatial average pooling layer\n",
    "# x = base_model.output\n",
    "# x = GlobalAveragePooling2D()(x)\n",
    "# # add a fully-connected layer\n",
    "# x = Dense(1024, activation='relu')(x)\n",
    "# # and a logistic layer -- let's say we have 7 classes\n",
    "# predictions = Dense(7, activation='softmax')(x) \n",
    "# model = Model(inputs=base_model.input, outputs=predictions)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.9454877  0.90882546 0.5331726  0.8134395  0.76214784]]\n"
     ]
    }
   ],
   "source": [
    "img_path = \"../pics/woodgirl.jpg\"\n",
    "x = preprocess_img(img_path)\n",
    "preds = model.predict(x)\n",
    "print(preds)\n",
    "# predict_top_labels(preds,3)\n",
    "# after fitting the model and calculate predictions you have to manually assign the \n",
    "# class name to output number without using imported decode_predictions\n",
    "# https://stackoverflow.com/questions/49259361/valueerror-decode-predictions-expects-a-batch-of-predictions-i-e-a-2d-array\n",
    "#https://stackoverflow.com/questions/49259361/valueerror-decode-predictions-expects-a-batch-of-predictions-i-e-a-2d-array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_predict(img_path):\n",
    "    print(img_path)\n",
    "    x = preprocess_img(img_path)\n",
    "    preds = model.predict(x)\n",
    "    print(preds)\n",
    "#     predict_top_labels(preds,3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# x_train_batch, y_train_batch = tf.train.shuffle_batch(\n",
    "#     tensors=[data.train.images, data.train.labels.astype(np.int32)],\n",
    "#     batch_size=batch_size,\n",
    "#     capacity=capacity,\n",
    "#     min_after_dequeue=min_after_dequeue,\n",
    "#     enqueue_many=enqueue_many,\n",
    "#     num_threads=8)\n",
    "# load tfrecord to dataset\n",
    "import tensorflow as tf\n",
    "# dataset=tf.contrib.data.TFRecordDataset(\"train.bin\").repeat()\n",
    "\n",
    "# dataset=tf.data.TFRecordDataset(\"train.bin\").repeat()\n",
    "filenames = [\"../../../Dataset/miml-image-data/processed/train.tfrecord\", \n",
    "             \"../../../Dataset/miml-image-data/processed/test.tfrecord\"]\n",
    "# my_dataset = tf.contrib.data.TFRecordDataset(filenames)\n",
    "train_dataset = tf.data.TFRecordDataset([filenames[0]])\n",
    "test_dataset = tf.data.TFRecordDataset([filenames[1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<TFRecordDataset shapes: (), types: tf.string>\n",
      "<TFRecordDataset shapes: (), types: tf.string>\n"
     ]
    }
   ],
   "source": [
    "print(train_dataset)\n",
    "print(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'os' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-e004168feecb>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0mbase_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"../pics/nature\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mvisit_dir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbase_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-1-e004168feecb>\u001b[0m in \u001b[0;36mvisit_dir\u001b[0;34m(path)\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvisit_dir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m     \u001b[0mli\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mli\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m         \u001b[0mpathname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m         \u001b[0mmy_predict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpathname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'os' is not defined"
     ]
    }
   ],
   "source": [
    "import os\n",
    "def visit_dir(path):\n",
    "    li = os.listdir(path)\n",
    "    for p in li:\n",
    "        pathname = os.path.join(path,p)\n",
    "        my_predict(pathname)\n",
    "#         my_predict(pathname)\n",
    "    \n",
    "base_path = \"../pics/nature\"\n",
    "visit_dir(base_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_tf(example):\n",
    "    dics = {}\n",
    "    dics['label_1'] = tf.FixedLenFeature(shape=[5],dtype=tf.int64)\n",
    "    dics['label_2'] = tf.FixedLenFeature(shape=[5], dtype=tf.int64)\n",
    "    dics['label_3'] = tf.FixedLenFeature(shape=[5], dtype=tf.int64)\n",
    "    dics['label_4'] = tf.FixedLenFeature(shape=[5], dtype=tf.int64)\n",
    "    dics['label_5'] = tf.FixedLenFeature(shape=[5], dtype=tf.int64)\n",
    "    dics['raw_image'] = tf.FixedLenFeature(shape=[],dtype=tf.string)\n",
    "    parsed = tf.parse_single_example(example,features=dics)\n",
    "    image = tf.decode_raw(parsed['raw_image'],out_type=tf.uint8)\n",
    "    image = tf.reshape(image,shape=[224,224,3])\n",
    "    image = tf.image.per_image_standardization(image)\n",
    "    label_1 = parsed['label_1']\n",
    "    label_2 = parsed['label_2']\n",
    "    label_3 = parsed['label_3']\n",
    "    label_4 = parsed['label_4']\n",
    "    label_5 = parsed['label_5']\n",
    "\n",
    "    label_1 = tf.cast(label_1,tf.int32)\n",
    "    label_2 = tf.cast(label_2, tf.int32)\n",
    "    label_3 = tf.cast(label_3, tf.int32)\n",
    "    label_4 = tf.cast(label_4, tf.int32)\n",
    "    label_5 = tf.cast(label_5, tf.int32)\n",
    "\n",
    "    return image,label_1,label_2,label_3,label_4,label_5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://stackoverflow.com/questions/42184863/how-do-you-make-tensorflow-keras-fast-with-a-tfrecord-dataset\n",
    "from keras import backend as K\n",
    "sess = tf.Session()\n",
    "K.set_session(sess)\n",
    "x_train, x_test, y_train, y_test = load()\n",
    "\n",
    "x_train = x_train.astype('float32')\n",
    "x_test  = x_test.astype('float32')\n",
    "\n",
    "x_train /= 255\n",
    "x_test /= 255\n",
    "# train_dataset = train_dataset.map(parse_tf)\n",
    "# train_dataset = train_dataset.batch(16).repeat(1)\n",
    "# train_iter = train_dataset.make_one_shot_iterator()\n",
    "\n",
    "# test_dataset = test_dataset.map(parse_tf)\n",
    "# test_dataset = test_dataset.batch(16).repeat(1)\n",
    "# test_iter = test_dataset.make_one_shot_iterator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Error when checking input: expected input_1 to have shape (244, 244, 3) but got array with shape (3, 100, 100)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-13-12bdd358c82c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      3\u001b[0m history = model.fit(x_train,y_train,\n\u001b[1;32m      4\u001b[0m                 \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m                 batch_size=batch_size)\n\u001b[0m\u001b[1;32m      6\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0maccuracy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_test\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Test loss: '\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)\u001b[0m\n\u001b[1;32m   1628\u001b[0m             \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1629\u001b[0m             \u001b[0mclass_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclass_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1630\u001b[0;31m             batch_size=batch_size)\n\u001b[0m\u001b[1;32m   1631\u001b[0m         \u001b[0;31m# Prepare validation data.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1632\u001b[0m         \u001b[0mdo_validation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36m_standardize_user_data\u001b[0;34m(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)\u001b[0m\n\u001b[1;32m   1474\u001b[0m                                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_feed_input_shapes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1475\u001b[0m                                     \u001b[0mcheck_batch_axis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1476\u001b[0;31m                                     exception_prefix='input')\n\u001b[0m\u001b[1;32m   1477\u001b[0m         y = _standardize_input_data(y, self._feed_output_names,\n\u001b[1;32m   1478\u001b[0m                                     \u001b[0moutput_shapes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36m_standardize_input_data\u001b[0;34m(data, names, shapes, check_batch_axis, exception_prefix)\u001b[0m\n\u001b[1;32m    121\u001b[0m                             \u001b[0;34m': expected '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnames\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' to have shape '\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    122\u001b[0m                             \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' but got array with shape '\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 123\u001b[0;31m                             str(data_shape))\n\u001b[0m\u001b[1;32m    124\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    125\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Error when checking input: expected input_1 to have shape (244, 244, 3) but got array with shape (3, 100, 100)"
     ]
    }
   ],
   "source": [
    "batch_size = 32\n",
    "epochs = 200\n",
    "history = model.fit(x_train,y_train,\n",
    "                epochs=epochs,\n",
    "                batch_size=batch_size)\n",
    "loss,accuracy = model.evaluate(x_test,y_test)\n",
    "print('Test loss: ',loss)                \n",
    "print('Test accuracy: ',accuracy)\n",
    "print(\"多标签模型训练完毕\")\n",
    "model_path = \"../model_path/multi-label.h5\"\n",
    "model.save(model_path)\n",
    "scores = model.evaluate(x_test, y_test, verbose=1)\n",
    "print('Test loss:', scores[0])\n",
    "print('Test accuracy:', scores[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
